import json

import torch
import math
import numpy as np
import torch.optim as optim

from generic.data_util import ICEHOCKEY_GAME_FEATURES, ICEHOCKEY_ACTIONS, get_nf_input
from layers.nf_nn_bak import MAF, MADE


def build_maf(agent):
    """Instantiate the flow model named by ``agent.maf_flow_type``.

    The flow type is matched case-insensitively by substring: a type
    containing ``"maf"`` builds a ``MAF``, otherwise one containing
    ``"made"`` builds a ``MADE``. The model is moved to ``agent.device``
    and stored on the agent together with an Adam optimizer.

    Args:
        agent: Object providing ``maf_flow_type``, ``maf_num_inputs``,
            ``maf_num_hidden``, ``device`` and ``maf_lr`` (plus
            ``maf_num_blocks`` when a MAF is requested).

    Side effects:
        Sets ``agent.maf_model`` and ``agent.maf_optim``.

    Raises:
        ValueError: If the flow type contains neither ``"maf"`` nor
            ``"made"``.
    """
    flow_kind = agent.maf_flow_type.lower()

    # Pick the class and its keyword arguments first, construct afterwards.
    # Agent attributes are only read inside the branch that needs them.
    if "maf" in flow_kind:
        flow_cls = MAF
        flow_kwargs = {
            "dim": agent.maf_num_inputs,
            "n_layers": agent.maf_num_blocks,
            "hidden_dims": [agent.maf_num_hidden],
            "device": agent.device,
        }
    elif "made" in flow_kind:
        flow_cls = MADE
        flow_kwargs = {
            "n_in": agent.maf_num_inputs,
            "hidden_dims": [agent.maf_num_hidden],
            "device": agent.device,
            "random_order": False,
            "seed": 290713,
            "gaussian": True,
        }
    else:
        raise ValueError("Unknown model type {0}".format(agent.maf_flow_type))

    flow = flow_cls(**flow_kwargs)
    flow.to(agent.device)

    agent.maf_model = flow
    agent.maf_optim = optim.Adam(
        flow.parameters(),
        lr=agent.maf_lr,
        weight_decay=1e-6,
    )


def update_maf(agent, batch, sanity_check_msg):
    """Run one optimisation step of the agent's flow model on ``batch``.

    The model input is produced by ``get_nf_input`` from
    ``batch.state_action`` and ``batch.trace``. With ``D`` equal to
    ``tgt_data.shape[1]``, the per-row loss is::

        0.5 * sum(u ** 2) + 0.5 * D * log(2 * pi) - log_det

    and its batch mean is minimised with ``agent.maf_optim``.

    Args:
        agent: Object providing ``maf_flow_type``, ``maf_model``,
            ``maf_optim`` and ``maf_apply_history``.
        batch: Object providing ``state_action`` and ``trace``.
        sanity_check_msg: Forwarded unchanged to ``get_nf_input``.

    Returns:
        ``(loss, logprob)`` where ``loss`` is the mean loss that was
        back-propagated and ``logprob`` is the negated mean of the loss
        without the ``log_det`` term. Returns ``(None, None)`` when the
        loss is NaN; in that case no gradient step is taken.

    Raises:
        ValueError: If ``agent.maf_flow_type`` contains neither ``"maf"``
            nor ``"made"`` (case-insensitive).
    """
    tgt_data = get_nf_input(agent=agent,
                            state_action=batch.state_action,
                            trace=batch.trace,
                            apply_history=agent.maf_apply_history,
                            sanity_check_msg=sanity_check_msg)

    flow_type = agent.maf_flow_type.lower()
    is_maf = "maf" in flow_type
    if not is_maf and "made" not in flow_type:
        raise ValueError("Unknown model type {0}".format(agent.maf_flow_type))

    # Cast once and reuse, so the MADE residual below is computed from the
    # same float32 tensor that was fed to the model.
    inputs = tgt_data.float()

    if is_maf:
        u, log_det = agent.maf_model.forward(inputs)
    else:
        out = agent.maf_model.forward(inputs)
        # The output is split in two halves along dim 1: the first is
        # subtracted from the input, the second scales the residual by
        # exp(0.5 * logp) and contributes 0.5 * sum(logp) as log_det.
        mu, logp = torch.chunk(out, 2, dim=1)
        u = (inputs - mu) * torch.exp(0.5 * logp)
        log_det = 0.5 * torch.sum(logp, dim=1)

    negloglik_loss = 0.5 * (u ** 2).sum(dim=1)
    negloglik_loss = negloglik_loss + 0.5 * tgt_data.shape[1] * np.log(2 * math.pi)
    loss = torch.mean(negloglik_loss - log_det)

    # Guard both flow types: stepping on a NaN loss would write NaNs into
    # the parameters and the optimizer state.
    if torch.isnan(loss):
        print('skip the nan loss.')
        return None, None

    logprob = -torch.mean(negloglik_loss)
    agent.maf_optim.zero_grad()
    loss.backward()
    agent.maf_optim.step()

    return loss, logprob

    # pbar.update(data.size(0))
    # pbar.set_description('Train, Log likelihood in nats: {:.6f}'.format(
    #     -train_loss / (batch_idx + 1)))
    #
    # writer.add_scalar('training/loss', loss.item(), global_step)
    # global_step += 1


def validate_maf(agent, batch, sanity_check_msg):
    """Evaluate the agent's flow model on ``batch`` without gradients.

    The model input is produced by ``get_nf_input`` from
    ``batch.state_action`` and ``batch.trace``. No reduction over the
    batch is applied: values are only summed over dim 1.

    Args:
        agent: Object providing ``maf_flow_type``, ``maf_model`` and
            ``maf_apply_history``.
        batch: Object providing ``state_action`` and ``trace``.
        sanity_check_msg: Forwarded unchanged to ``get_nf_input``.

    Returns:
        A pair of NumPy arrays ``(loss, logprob)``: the loss including
        the ``log_det`` term, and the negated loss without it.

    Raises:
        ValueError: If ``agent.maf_flow_type`` contains neither ``"maf"``
            nor ``"made"`` (case-insensitive).
    """
    tgt_data = get_nf_input(agent=agent,
                            state_action=batch.state_action,
                            trace=batch.trace,
                            apply_history=agent.maf_apply_history,
                            sanity_check_msg=sanity_check_msg)

    with torch.no_grad():
        flow_kind = agent.maf_flow_type.lower()
        if "maf" in flow_kind:
            noise, log_det = agent.maf_model.forward(tgt_data.float())
        elif "made" in flow_kind:
            mu, logp = torch.chunk(
                agent.maf_model.forward(tgt_data.float()), 2, dim=1)
            # Residual scaled by exp(0.5 * logp); the matching
            # log-determinant term is 0.5 * sum(logp).
            noise = (tgt_data - mu) * torch.exp(0.5 * logp)
            log_det = 0.5 * torch.sum(logp, dim=1)
        else:
            raise ValueError("Unknown model type {0}".format(agent.maf_flow_type))

        # Shared tail: quadratic term plus the 0.5 * D * log(2*pi) constant.
        base_nll = 0.5 * (noise ** 2).sum(dim=1)
        base_nll += 0.5 * tgt_data.shape[1] * np.log(2 * math.pi)
        per_row_loss = base_nll - log_det

    loss_np = per_row_loss.detach().cpu().numpy()
    logprob_np = -base_nll.detach().cpu().numpy()
    return loss_np, logprob_np
